<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8" >

<title>weight.data in PyTorch | KALUO</title>
<meta name="description" content="Review the old to learn the new">

<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">

<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.7.2/css/all.css" integrity="sha384-fnmOCqbTlWIlj8LyTjo7mOUStjsKC4pOpQbqyi7RrhN7udi9RwhKkMHpvLbHG9Sr" crossorigin="anonymous">
<link rel="shortcut icon" href="https://kaluo-zz.github.io/favicon.ico?v=1591084625855">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.10.0/katex.min.css">
<link rel="stylesheet" href="https://kaluo-zz.github.io/styles/main.css">



<script src="https://cdn.jsdelivr.net/npm/vue/dist/vue.js"></script>
<script src="https://cdn.bootcss.com/highlight.js/9.12.0/highlight.min.js"></script>

<link rel="stylesheet" href="https://unpkg.com/aos@next/dist/aos.css" />



  </head>
  <body>
    <div id="app" class="main">

      <div class="sidebar" :class="{ 'full-height': menuVisible }">
  <div class="top-container" data-aos="fade-right">
    <div class="top-header-container">
      <a class="site-title-container" href="https://kaluo-zz.github.io">
        <img src="https://kaluo-zz.github.io/images/avatar.png?v=1591084625855" class="site-logo">
        <h1 class="site-title">KALUO</h1>
      </a>
      <div class="menu-btn" @click="menuVisible = !menuVisible">
        <div class="line"></div>
      </div>
    </div>
    <div>
      
        
          <a href="/" class="site-nav">
            Home
          </a>
        
      
        
          <a href="/archives" class="site-nav">
            Archives
          </a>
        
      
        
          <a href="/tags" class="site-nav">
            Tags
          </a>
        
      
        
          <a href="/post/about" class="site-nav">
            About
          </a>
        
      
    </div>
  </div>
  <div class="bottom-container" data-aos="flip-up" data-aos-offset="0">
    <div class="social-container">
      
        
      
        
      
        
      
        
      
        
      
    </div>
    <div class="site-description">
      Review the old to learn the new
    </div>
    <div class="site-footer">
      Powered by <a href="https://github.com/getgridea/gridea" target="_blank">Gridea</a> | <a class="rss" href="https://kaluo-zz.github.io/atom.xml" target="_blank">RSS</a>
    </div>
  </div>
</div>


      <div class="main-container">
        <div class="content-container" data-aos="fade-up">
          <div class="post-detail">
            <h2 class="post-title">weight.data in PyTorch</h2>
            <div class="post-date">2019-08-09</div>
            
              <div class="feature-container" style="background-image: url('https://kaluo-zz.github.io/post-images/pytorch-zhi-weightdata.jpg')">
              </div>
            
            <div class="post-content">
              <p><code>conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)</code></p>
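<p>In PyTorch, the kernel weights of a convolution layer are stored in <code>conv.weight.data</code>. Calling <code>.cpu().numpy()</code> on it converts them to a NumPy array of shape <code>[out_channels, in_channels, kernel_h, kernel_w]</code> (with the default <code>groups=1</code>; in general the second dimension is <code>in_channels / groups</code>).</p>
<p>A NumPy array can also be assigned back to the kernel, as long as its shape matches and its dtype is the layer's dtype (<code>float32</code> by default). The line below moves the tensor to the GPU with <code>.cuda()</code>, so it assumes the layer lives on a CUDA device:</p>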
<p><code>conv.weight.data = torch.from_numpy(new_weights).cuda()</code></p>
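<p>A minimal sketch of the round trip described above. The layer sizes and the replacement values are hypothetical, chosen only for illustration, and the tensor is sent to whatever device the layer already uses instead of assuming a GPU:</p>

```python
import numpy as np
import torch

# Hypothetical layer sizes, chosen only for illustration.
conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

# Export: .data is detached from autograd, so .numpy() can be called directly.
weights = conv.weight.data.cpu().numpy()
print(weights.shape)  # (8, 3, 3, 3)

# Import: the array must have the kernel's shape. NumPy defaults to float64,
# so the dtype is set to float32 explicitly to match the layer.
new_weights = np.zeros(weights.shape, dtype=np.float32)
new_weights[:, :, 1, 1] = 1.0  # put 1.0 at the center of every 3x3 kernel

# Keep the tensor on the same device as the existing parameter.
conv.weight.data = torch.from_numpy(new_weights).to(conv.weight.device)
print(conv.weight.data[0, 0])
```

<p>On a CPU tensor, <code>.numpy()</code> shares memory with the tensor, so call <code>.copy()</code> on the array if it should stay independent of later changes to the layer.</p>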
<p>The same applies to <code>torch.nn.Linear</code>.</p>
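<p>For a fully connected layer the weight is two-dimensional, with shape <code>[out_features, in_features]</code>. A short sketch, with hypothetical layer sizes:</p>

```python
import torch

# Hypothetical layer sizes, chosen only for illustration.
fc = torch.nn.Linear(in_features=4, out_features=2)

w = fc.weight.data.cpu().numpy()
b = fc.bias.data.cpu().numpy()
print(w.shape)  # (2, 4)
print(b.shape)  # (2,)
```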
<p>Given a model <code>net</code>, how do we get its weights?</p>
<pre><code>net = Net()  # Net is the model class, defined elsewhere

# Materialize the (name, tensor) pairs once so they can be indexed
params = list(net.state_dict().items())

# weight of the first convolution layer
weight_name, weight = params[0]
print(weight_name)   # name of the weight
print(weight)        # value of the weight (weight.data prints the same values)
print(weight.shape)  # shape of the weight (same as weight.data.shape)

# bias of the first convolution layer
bias_name, bias = params[1]
print(bias_name)   # name of the bias
print(bias)        # value of the bias
print(bias.shape)  # shape of the bias

</code></pre>
<p>Output:</p>
<pre><code>frontend.0.weight
tensor([[[[-5.5373e-01,  1.4270e-01,  5.2896e-01],
          [-5.8312e-01,  3.5655e-01,  7.6566e-01],
          [-6.9022e-01, -4.8019e-02,  4.8409e-01]],

         [[ 1.7548e-01,  9.8630e-03, -8.1413e-02],
          [ 4.4089e-02, -7.0323e-02, -2.6035e-01],
          [ 1.3239e-01, -1.7279e-01, -1.3226e-01]],

         [[ 3.1303e-01, -1.6591e-01, -4.2752e-01],
          [ 4.7519e-01, -8.2677e-02, -4.8700e-01],
          [ 6.3203e-01,  1.9308e-02, -2.7753e-01]]],


        [[[ 2.3254e-01,  1.2666e-01,  1.8605e-01],
          [-4.2805e-01, -2.4349e-01,  2.4628e-01],
          [-2.5066e-01,  1.4177e-01, -5.4864e-03]],

         [[-1.4076e-01, -2.1903e-01,  1.5041e-01],
          [-8.4127e-01, -3.5176e-01,  5.6398e-01],
          [-2.4194e-01,  5.1928e-01,  5.3915e-01]],

         [[-3.1432e-01, -3.7048e-01, -1.3094e-01],
          [-4.7144e-01, -1.5503e-01,  3.4589e-01],
          [ 5.4384e-02,  5.8683e-01,  4.9580e-01]]],


        [[[ 1.7715e-01,  5.2149e-01,  9.8740e-03],
          [-2.7185e-01, -7.1709e-01,  3.1292e-01],
          [-7.5753e-02, -2.2079e-01,  3.3455e-01]],

         [[ 3.0924e-01,  6.7071e-01,  2.0546e-02],
          [-4.6607e-01, -1.0697e+00,  3.3501e-01],
          [-8.0284e-02, -3.0522e-01,  5.4460e-01]],

         [[ 3.1572e-01,  4.2335e-01, -3.4976e-01],
          [ 8.6354e-02, -4.6457e-01,  1.1803e-02],
          [ 1.0483e-01, -1.4584e-01, -1.5765e-02]]],


        ...,


        [[[ 7.7599e-02,  1.2692e-01,  3.2305e-02],
          [ 2.2131e-01,  2.4681e-01, -4.6637e-02],
          [ 4.6407e-02,  2.8246e-02,  1.7528e-02]],

         [[-1.8327e-01, -6.7425e-02, -7.2120e-03],
          [-4.8855e-02,  7.0427e-03, -1.2883e-01],
          [-6.4601e-02, -6.4566e-02,  4.4235e-02]],

         [[-2.2547e-01, -1.1931e-01, -2.3425e-02],
          [-9.9171e-02, -1.5143e-02,  9.5385e-04],
          [-2.6137e-02,  1.3567e-03,  1.4282e-01]]],


        [[[ 1.6520e-02, -3.2225e-02, -3.8450e-03],
          [-6.8206e-02, -1.9445e-01, -1.4166e-01],
          [-6.9528e-02, -1.8340e-01, -1.7422e-01]],

         [[ 4.2781e-02, -6.7529e-02, -7.0309e-03],
          [ 1.1765e-02, -1.4958e-01, -1.2361e-01],
          [ 1.0205e-02, -1.0393e-01, -1.1742e-01]],

         [[ 1.2661e-01,  8.5046e-02,  1.3066e-01],
          [ 1.7585e-01,  1.1288e-01,  1.1937e-01],
          [ 1.4656e-01,  9.8892e-02,  1.0348e-01]]],


        [[[ 3.2176e-02, -1.0766e-01, -2.6388e-01],
          [ 2.7957e-01, -3.7416e-02, -2.5471e-01],
          [ 3.4872e-01,  3.0041e-02, -5.5898e-02]],

         [[ 2.5063e-01,  1.5543e-01, -1.7432e-01],
          [ 3.9255e-01,  3.2306e-02, -3.5191e-01],
          [ 1.9299e-01, -1.9898e-01, -2.9713e-01]],

         [[ 4.6032e-01,  4.3399e-01,  2.8352e-01],
          [ 1.6341e-01, -5.8165e-02, -1.9196e-01],
          [-1.9521e-01, -4.5630e-01, -4.2732e-01]]]])
torch.Size([64, 3, 3, 3])
frontend.0.bias
tensor([ 0.4034,  0.3778,  0.4644, -0.3228,  0.3940, -0.3953,  0.3951, -0.5496,
         0.2693, -0.7602, -0.3508,  0.2334, -1.3239, -0.1694,  0.3938, -0.1026,
         0.0460, -0.6995,  0.1549,  0.5628,  0.3011,  0.3425,  0.1073,  0.4651,
         0.1295,  0.0788, -0.0492, -0.5638,  0.1465, -0.3890, -0.0715,  0.0649,
         0.2768,  0.3279,  0.5682, -1.2640, -0.8368, -0.9485,  0.1358,  0.2727,
         0.1841, -0.5325,  0.3507, -0.0827, -1.0248, -0.6912, -0.7711,  0.2612,
         0.4033, -0.4802, -0.3066,  0.5807, -1.3325,  0.4844, -0.8160,  0.2386,
         0.2300,  0.4979,  0.5553,  0.5230, -0.2182,  0.0117, -0.5516,  0.2108])
torch.Size([64])
</code></pre>
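<p><code>net.state_dict().items()</code> returns an iterable view of <code>(name, tensor)</code> pairs, so the weights can be traversed with a for loop: <code>for k, v in net.state_dict().items()</code>, where <code>k</code> is the name of the entry and <code>v</code> is its value.</p>
<p>The view cannot be indexed by position, so wrap it in <code>list</code> when an index is needed: <code>list(net.state_dict().items())</code>.</p>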
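<p>The loop form can be sketched as follows. <code>TinyNet</code> is a hypothetical stand-in for the post's <code>Net</code>, which is not shown; it only mirrors the <code>frontend.0</code> naming and the <code>[64, 3, 3, 3]</code> weight shape seen in the output above:</p>

```python
import torch


class TinyNet(torch.nn.Module):
    """Hypothetical stand-in for the post's Net, which is not shown."""

    def __init__(self):
        super().__init__()
        self.frontend = torch.nn.Sequential(
            torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        return self.frontend(x)


net = TinyNet()

# Collect each entry's name and shape while iterating over the view.
shapes = {}
for k, v in net.state_dict().items():
    shapes[k] = tuple(v.shape)
    print(k, tuple(v.shape))
```

<p>Entries appear in the order the layers were registered. Layers without parameters or buffers, such as <code>ReLU</code>, contribute nothing to the <code>state_dict</code>.</p>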

            </div>
            
              <div class="tag-container">
                
                  <a href="https://kaluo-zz.github.io/tag/qNRCu_hfZ" class="tag">
                    data layout
                  </a>
                
                  <a href="https://kaluo-zz.github.io/tag/_JunyD00tX" class="tag">
                    convolution kernel
                  </a>
                
                  <a href="https://kaluo-zz.github.io/tag/rT7AxA2VY" class="tag">
                    pytorch
                  </a>
                
              </div>
            
            
              <div class="next-post">
                <div class="next">Next post</div>
                <a href="https://kaluo-zz.github.io/post/python-jiao-ben-shu-chu-chong-ding-xiang-wen-ti">
                  <h3 class="post-title">
                    Output redirection issues in Python scripts
                  </h3>
                </a>
              </div>
            

            

          </div>

        </div>
      </div>
    </div>

    <script src="https://unpkg.com/aos@next/dist/aos.js"></script>

<script type="application/javascript">

AOS.init();

hljs.initHighlightingOnLoad()

var app = new Vue({
  el: '#app',
  data: {
    menuVisible: false,
  },
})

</script>




  </body>
</html>
